{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "8f2df1fc-db58-4c56-87b0-5899d61846cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "9e550665-465b-46e5-8f0d-8e5dabaf22c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.read_csv('./汽车领域文本规则泛化性增强挑战赛公开数据/train.csv')\n",
    "train_data['规则id'] = train_data['规则id'].astype(str)\n",
    "\n",
    "test_data = pd.read_csv('./汽车领域文本规则泛化性增强挑战赛公开数据/test_a.csv')\n",
    "\n",
    "rules = pd.read_excel('./汽车领域文本规则泛化性增强挑战赛公开数据/规则数据.xlsx')\n",
    "rules['规则id'] = rules['规则id'].astype(str)\n",
    "\n",
    "train_data = pd.merge(train_data, rules, left_on='规则id', right_on='规则id', how='left')\n",
    "train_data = train_data.sample(frac=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "6330d088-dda7-4fb8-94ae-1e70722bc719",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_examples = []\n",
    "for row in train_data.iterrows():\n",
    "    neg_rules = rules['规则表达式'].sample(5).values\n",
    "    pos_rule = row[1]['规则表达式']\n",
    "\n",
    "    for text in neg_rules:\n",
    "        if text == row[1]['规则表达式']:\n",
    "            continue\n",
    "        train_examples.append(InputExample(texts=[row[1].text, text], label=0.0))\n",
    "\n",
    "    if pos_rule is not np.nan:\n",
    "        train_examples.append(InputExample(texts=[row[1].text, pos_rule], label=1.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "2112591a-4fef-4b9f-a804-130a8cd58677",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30775"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "0d520c43-4a79-4776-ac95-d24e7a4387a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No sentence-transformers model found with name /home/lyz/下载/m3e-small/. Creating a new one with MEAN pooling.\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer, models, losses, InputExample, evaluation\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# https://huggingface.co/moka-ai/m3e-small\n",
    "# model = SentenceTransformer('moka-ai/m3e-small')\n",
    "\n",
    "word_embedding_model = models.Transformer('/home/lyz/下载/m3e-small/', max_seq_length=80)\n",
    "\n",
    "tokens = [\"{open}\", \"{close}\", \"{set_up}\", \"{set_down}\", \"{set}\", \"{set_min}\", \"{set_max}\", \"{NUM}\"]\n",
    "word_embedding_model.tokenizer.add_tokens(tokens, special_tokens=True)\n",
    "word_embedding_model.auto_model.resize_token_embeddings(len(word_embedding_model.tokenizer))\n",
    "\n",
    "pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())\n",
    "dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(), out_features=256, activation_function=nn.Tanh())\n",
    "model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])\n",
    "\n",
    "model = SentenceTransformer('/home/lyz/下载/m3e-small/')\n",
    "train_dataloader = DataLoader(train_examples[:], shuffle=True, batch_size=16)\n",
    "# val_dataloader = DataLoader(train_examples[-6000:], shuffle=True, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "cd7a1856-3b6b-46f2-9f25-0e1675431f1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "68ce00b0ef764152aa699aa65d9a9eb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5c5b7908b424ee0aa0534b4be06f257",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01d21b5d89994897be4afdeed22c844c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "316ccaa989534b52bf9673fbb712570b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce0d9e97e5a14e489558459c8d3637cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eaa28a99f3d44ab397334ee91275c3fc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/1924 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_loss = losses.CosineSimilarityLoss(model)\n",
    "model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=5, warmup_steps=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "7dc69c96-0a37-4e55-b3c4-4f061c323fc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>规则id</th>\n",
       "      <th>规则表达式</th>\n",
       "      <th>语义参考</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>176</th>\n",
       "      <td>打开空调打开冷空调</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2255</th>\n",
       "      <td>空调关低点</td>\n",
       "      <td>22</td>\n",
       "      <td>空调{set_down}\\n空调(温度){set_down}\\n温度{set_down}</td>\n",
       "      <td>{\"temperature\":\"minus\"}@set@aircontrol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2016</th>\n",
       "      <td>空调最冷</td>\n",
       "      <td>14</td>\n",
       "      <td>空调(温度){set_min}\\n空调{set_min}\\n温度{set}{set_min}...</td>\n",
       "      <td>{\"temperature\":\"min\"}@set@aircontrol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5106</th>\n",
       "      <td>请设置到导航全屏模式</td>\n",
       "      <td>540</td>\n",
       "      <td>({open}|{set})[到]导航全屏[模式]</td>\n",
       "      <td>{\"insType\":\"EXIT_FULL_SCREEN\"}@mapu</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4562</th>\n",
       "      <td>嗯打开音乐</td>\n",
       "      <td>503</td>\n",
       "      <td>{open}(音乐|歌)</td>\n",
       "      <td>{\"instype\":\"open\"}@instruction@musicx</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4850</th>\n",
       "      <td>帮我设置自定义唤醒词</td>\n",
       "      <td>529</td>\n",
       "      <td>(设置|修改)自定义唤醒词</td>\n",
       "      <td>{\"insType\":\"OPEN_CUSTOM_WAKEUP_WORD\"}@cmd</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2461</th>\n",
       "      <td>自动空调</td>\n",
       "      <td>30</td>\n",
       "      <td>空调{set}自动模式\\n空调自动\\n空调{set}自动\\n{open}自动模式</td>\n",
       "      <td>{\"mode\":\"自动\"}@set@aircontrol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2466</th>\n",
       "      <td>空调调成自动</td>\n",
       "      <td>30</td>\n",
       "      <td>空调{set}自动模式\\n空调自动\\n空调{set}自动\\n{open}自动模式</td>\n",
       "      <td>{\"mode\":\"自动\"}@set@aircontrol</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>411</th>\n",
       "      <td>打开空调制冷一档</td>\n",
       "      <td>None</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2931</th>\n",
       "      <td>打开全部窗户</td>\n",
       "      <td>101</td>\n",
       "      <td>{open}所有车窗\\n全部车窗{open}\\n{open}所有窗户</td>\n",
       "      <td>{\"name\":\"所有窗户\"}@open@carcontrol</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5312 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            text  规则id                                              规则表达式  \\\n",
       "176    打开空调打开冷空调  None                                                NaN   \n",
       "2255       空调关低点    22       空调{set_down}\\n空调(温度){set_down}\\n温度{set_down}   \n",
       "2016        空调最冷    14  空调(温度){set_min}\\n空调{set_min}\\n温度{set}{set_min}...   \n",
       "5106  请设置到导航全屏模式   540                          ({open}|{set})[到]导航全屏[模式]   \n",
       "4562       嗯打开音乐   503                                       {open}(音乐|歌)   \n",
       "...          ...   ...                                                ...   \n",
       "4850  帮我设置自定义唤醒词   529                                      (设置|修改)自定义唤醒词   \n",
       "2461        自动空调    30           空调{set}自动模式\\n空调自动\\n空调{set}自动\\n{open}自动模式   \n",
       "2466      空调调成自动    30           空调{set}自动模式\\n空调自动\\n空调{set}自动\\n{open}自动模式   \n",
       "411     打开空调制冷一档  None                                                NaN   \n",
       "2931      打开全部窗户   101                 {open}所有车窗\\n全部车窗{open}\\n{open}所有窗户   \n",
       "\n",
       "                                           语义参考  \n",
       "176                                         NaN  \n",
       "2255     {\"temperature\":\"minus\"}@set@aircontrol  \n",
       "2016       {\"temperature\":\"min\"}@set@aircontrol  \n",
       "5106        {\"insType\":\"EXIT_FULL_SCREEN\"}@mapu  \n",
       "4562      {\"instype\":\"open\"}@instruction@musicx  \n",
       "...                                         ...  \n",
       "4850  {\"insType\":\"OPEN_CUSTOM_WAKEUP_WORD\"}@cmd  \n",
       "2461               {\"mode\":\"自动\"}@set@aircontrol  \n",
       "2466               {\"mode\":\"自动\"}@set@aircontrol  \n",
       "411                                         NaN  \n",
       "2931            {\"name\":\"所有窗户\"}@open@carcontrol  \n",
       "\n",
       "[5312 rows x 4 columns]"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "53c91c3e-6be5-465d-a375-9c084a7e6297",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_feat = model.encode(list(test_data['text']))\n",
    "test_feat = normalize(test_feat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "f3a7315e-d473-4715-95af-c70d9e5a7798",
   "metadata": {},
   "outputs": [],
   "source": [
    "rules_feat = model.encode(list(rules['规则表达式']))\n",
    "rules_feat = normalize(rules_feat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "9a3f0c98-7e8a-4aa5-9c57-b3b0bca17e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "dist = 1 - cdist(test_feat, rules_feat, metric='cosine')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "b355e3e4-b418-4297-8448-c7cbbd333327",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_rule = []\n",
    "for idx in range(len(dist)):\n",
    "    if dist[idx].max() < 0.8:\n",
    "        test_rule.append(None)\n",
    "    else:\n",
    "        test_rule.append(rules['规则id'].iloc[dist[idx].argmax()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "afab2273-f589-4f57-a6e8-0b1273da3e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data['规则id'] = test_rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "id": "a0a5b26a-c272-4da4-8a90-dd911e0f5c71",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data.to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab205b7-6c61-4994-ad5b-c92e51639113",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
